import asyncio
import time
import socket

from datetime import datetime

from pylog.pylogger import PyLogger

from py_pli.pylib import VUnits

import config_enum.excitationlight_selector_enum as els_config
import config_enum.filter_module_slider_enum as fms_config

from virtualunits.HAL import HAL
from virtualunits.vu_excitation_light_selector import VUExcitationLightSelector
from virtualunits.vu_filter_module_slider import VUFilterModuleSlider
from virtualunits.vu_detector_aperture_slider import VUDetectorApertureSlider
from virtualunits.vu_measurement_unit import VUMeasurementUnit
from virtualunits.meas_seq_generator import meas_seq_generator
from virtualunits.meas_seq_generator import TriggerSignal
from virtualunits.meas_seq_generator import OutputSignal

from urpc_enum.measurementparameter import MeasurementParameter


# Module-level handles to the virtual units used by the scans below. They are
# resolved once, at import time, from the VUnits.instance.hal singleton, so
# VUnits.instance must already be set up when this module is imported
# (NOTE(review): confirm where VUnits.instance is populated).
hal_unit: HAL = VUnits.instance.hal
# Excitation light selector; init() moves it to Positions.Flash2.
els_unit: VUExcitationLightSelector = VUnits.instance.hal.excitationLightSelector
# Filter module slider; init() moves it to Positions.FixMirrorPosition_c.
fms_unit: VUFilterModuleSlider = VUnits.instance.hal.filterModuleSlider
# Detector aperture sliders: flash_excitation_scan() pairs slider 1 with
# 'pmt1' / counter channel 0 and slider 2 with 'pmt2' / counter channel 1.
as1_unit: VUDetectorApertureSlider = VUnits.instance.hal.detectorApertureSlider1
as2_unit: VUDetectorApertureSlider = VUnits.instance.hal.detectorApertureSlider2
# Measurement unit: loads and executes the generated trigger sequences and
# returns the result buffer values.
meas_unit: VUMeasurementUnit = VUnits.instance.hal.measurementUnit


async def init():
    """Prepare the instrument for flash-lamp scans.

    Runs the HAL bring-up steps in order (hardware startup, device
    initialization, homing the movers, LEDs off, PMT high voltage on), waits
    one second, then moves the excitation light selector to ``Flash2`` and
    the filter module slider to ``FixMirrorPosition_c``.
    """
    bring_up_steps = (
        'StartupHardware',
        'InitializeDevice',
        'HomeMovers',
        'TurnLedsOff',
        'TurnOn_PMT_HV',
    )
    for step_name in bring_up_steps:
        # Each HAL method is resolved right before it is awaited.
        await getattr(hal_unit, step_name)()

    # One-second pause after switching the PMT high voltage on
    # (presumably settling time; confirm the required duration).
    await asyncio.sleep(1)

    excitation_position = els_config.Positions.Flash2
    await els_unit.GotoPosition(excitation_position)
    mirror_position = fms_config.Positions.FixMirrorPosition_c
    await fms_unit.GotoPosition(mirror_position)


async def flash_excitation_scan(pmt: str, as_pos, start_us=0.0, window_us=1.0, window_count=100, flash_count=1, flash_power=None, high_power=None):
    """Record a time-resolved pulse-count trace on one PMT after flash excitation.

    A measurement sequence fires the flash lamp ``flash_count`` times. After
    each flash it waits ``start_us`` and then counts PMT pulses in
    ``window_count`` consecutive windows of ``window_us`` each. Counts are
    accumulated per window over all flashes and returned as per-flash
    averages.

    :param pmt: ``'pmt1'`` (counter channel 0, aperture slider 1) or
        ``'pmt2'`` (counter channel 1, aperture slider 2).
    :param as_pos: Position for the selected PMT's aperture slider; the other
        slider is moved to position 0.
    :param start_us: Delay between flash trigger and the first window in us,
        range [0.0, 671088.64].
    :param window_us: Width of each counting window in us, range
        [0.1, 671088.64].
    :param window_count: Number of windows, range [1, 4096].
    :param flash_count: Number of flashes to accumulate, range [1, 65536].
    :param flash_power: Optional flash lamp power in [0.0, 1.0]. If None, the
        instrument setting is left unchanged.
    :param high_power: Optional high-power enable (0 or 1). If None, the
        current setting is read back from the instrument. It selects the
        arming time.
    :returns: List of ``window_count`` counts, each divided by ``flash_count``.
    :raises ValueError: If an argument is out of range or ``pmt`` is unknown.
    """
    if (start_us < 0.0) or (start_us > 671088.64):
        raise ValueError("start_us must be in the range [0.0, 671088.64] us")
    if (window_us < 0.1) or (window_us > 671088.64):
        raise ValueError("window_us must be in the range [0.1, 671088.64] us")
    if (window_count < 1) or (window_count > 4096):
        raise ValueError("window_count must be in the range [1, 4096]")
    if (flash_count < 1) or (flash_count > 65536):
        raise ValueError("flash_count must be in the range [1, 65536]")
    if (flash_power is not None) and ((flash_power < 0.0) or (flash_power > 1.0)):
        raise ValueError("flash_power must be in the range [0.0, 1.0]")
    if (high_power is not None) and (high_power != 0) and (high_power != 1):
        raise ValueError("high_power must be 0 or 1")

    # Select the counter channel and position the aperture in front of the
    # chosen PMT; the other slider goes to position 0 (presumably closed).
    if (pmt == 'pmt1'):
        await as1_unit.Move(as_pos)
        await as2_unit.Move(0)
        channel = 0
    elif (pmt == 'pmt2'):
        await as1_unit.Move(0)
        await as2_unit.Move(as_pos)
        channel = 1
    else:
        raise ValueError("pmt must be pmt1 or pmt2")

    if flash_power is not None:
        await meas_unit.endpoint.SetParameter(MeasurementParameter.FlashLampPower, flash_power, timeout=1)

    if high_power is not None:
        await meas_unit.endpoint.SetParameter(MeasurementParameter.FlashLampHighPowerEnable, high_power, timeout=1)
    else:
        high_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampHighPowerEnable, timeout=1))[0]

    # Short pause after updating the flash lamp parameters.
    await asyncio.sleep(0.1)

    # Sequencer timer values are in 10 ns ticks (400000 ticks = 4 ms = 250 Hz).
    if not high_power:
        arming = 50000   #  500 us arming time
    else:
        arming = 120000  # 1200 us arming time
    delay = round(start_us * 100)
    window = round(window_us * 100)
    # Pad each flash cycle to 4 ms (250 Hz flash frequency). If the scan does
    # not fit into one cycle, no padding is added and the rate drops below 250 Hz.
    loop_delay = 400000 - arming - delay - window * window_count

    op_id = 'flash_excitation_scan'
    seq_gen = meas_seq_generator()
    # PMT1 and PMT2 high voltage gate on
    seq_gen.SetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)
    # Clear the result buffer entries 0 .. window_count-1
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    seq_gen.Loop(window_count)
    seq_gen.ClearResultBuffer(relative=True, dword=False, addrReg=0, addr=0)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()

    seq_gen.Loop(flash_count)
    # Arm the flash lamp
    seq_gen.TimerWaitAndRestart(arming)
    seq_gen.SetSignals(OutputSignal.Flash)
    # Trigger the flash
    seq_gen.TimerWaitAndRestart(delay)
    seq_gen.ResetSignals(OutputSignal.Flash)
    # Start the scan with PMT after start_us has elapsed
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(channel, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    # Loop to measure each window and add its count to the result slot
    seq_gen.Loop(window_count)
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(channel, cumulative=False, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    seq_gen.GetPulseCounterResult(channel, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=0)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()
    if (loop_delay > 0):
        seq_gen.TimerWaitAndRestart(loop_delay)
    seq_gen.LoopEnd()
    seq_gen.Stop(0)

    meas_unit.ClearOperations()
    meas_unit.resultAddresses[op_id] = range(0, window_count)
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)
    await meas_unit.ExecuteMeasurement(op_id)
    results = await meas_unit.ReadMeasurementValues(op_id)

    # Convert accumulated counts into per-flash averages.
    results = [result / flash_count for result in results]

    return results


async def flash_excitation_dual_scan(as1_pos, as2_pos, start_us=0.0, window_us=1.0, window_count=100, flash_count=1, flash_power=None, high_power=None):
    """Record time-resolved pulse-count traces on both PMTs after flash excitation.

    A measurement sequence fires the flash lamp ``flash_count`` times. After
    each flash it waits ``start_us`` and then reads both pulse counter
    channels (0 and 1) in ``window_count`` consecutive windows of
    ``window_us`` each. Counts are accumulated per window and channel over
    all flashes and returned as per-flash averages.

    :param as1_pos: Position for detector aperture slider 1.
    :param as2_pos: Position for detector aperture slider 2.
    :param start_us: Delay between flash trigger and the first window in us,
        range [0.0, 671088.64].
    :param window_us: Width of each counting window in us, range
        [0.1, 671088.64].
    :param window_count: Number of windows, range [1, 2048]. Two results are
        stored per window.
    :param flash_count: Number of flashes to accumulate, range [1, 65536].
    :param flash_power: Optional flash lamp power in [0.0, 1.0]. If None, the
        instrument setting is left unchanged.
    :param high_power: Optional high-power enable (0 or 1). If None, the
        current setting is read back from the instrument. It selects the
        arming time.
    :returns: List of ``2 * window_count`` values divided by ``flash_count``,
        interleaved per window: even indices are channel 0, odd indices
        channel 1.
    :raises ValueError: If an argument is out of range.
    """
    if (start_us < 0.0) or (start_us > 671088.64):
        raise ValueError("start_us must be in the range [0.0, 671088.64] us")
    if (window_us < 0.1) or (window_us > 671088.64):
        raise ValueError("window_us must be in the range [0.1, 671088.64] us")
    if (window_count < 1) or (window_count > 2048):
        raise ValueError("window_count must be in the range [1, 2048]")
    if (flash_count < 1) or (flash_count > 65536):
        raise ValueError("flash_count must be in the range [1, 65536]")
    if (flash_power is not None) and ((flash_power < 0.0) or (flash_power > 1.0)):
        raise ValueError("flash_power must be in the range [0.0, 1.0]")
    if (high_power is not None) and (high_power != 0) and (high_power != 1):
        raise ValueError("high_power must be 0 or 1")

    await as1_unit.Move(as1_pos)
    await as2_unit.Move(as2_pos)

    if flash_power is not None:
        await meas_unit.endpoint.SetParameter(MeasurementParameter.FlashLampPower, flash_power, timeout=1)

    if high_power is not None:
        await meas_unit.endpoint.SetParameter(MeasurementParameter.FlashLampHighPowerEnable, high_power, timeout=1)
    else:
        high_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampHighPowerEnable, timeout=1))[0]

    # Short pause after updating the flash lamp parameters.
    await asyncio.sleep(0.1)

    # Sequencer timer values are in 10 ns ticks (400000 ticks = 4 ms = 250 Hz).
    if not high_power:
        arming = 50000   #  500 us arming time
    else:
        arming = 120000  # 1200 us arming time
    delay = round(start_us * 100)
    window = round(window_us * 100)
    # Pad each flash cycle to 4 ms (250 Hz flash frequency). If the scan does
    # not fit into one cycle, no padding is added and the rate drops below 250 Hz.
    loop_delay = 400000 - arming - delay - window * window_count

    op_id = 'flash_excitation_scan'
    seq_gen = meas_seq_generator()
    # PMT1 and PMT2 high voltage gate on
    seq_gen.SetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)
    # Clear the result buffer entries 0 .. 2*window_count-1
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    seq_gen.Loop(window_count * 2)
    seq_gen.ClearResultBuffer(relative=True, dword=False, addrReg=0, addr=0)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()

    seq_gen.Loop(flash_count)
    # Arm the flash lamp
    seq_gen.TimerWaitAndRestart(arming)
    seq_gen.SetSignals(OutputSignal.Flash)
    # Trigger the flash
    seq_gen.TimerWaitAndRestart(delay)
    seq_gen.ResetSignals(OutputSignal.Flash)
    # Start the scan with PMT1 and PMT2 after start_us has elapsed
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(channel=0, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.PulseCounterControl(channel=1, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    # Loop to measure each window; channel 0 goes to addr+0, channel 1 to
    # addr+1, then the address register advances by 2 (interleaved layout).
    seq_gen.Loop(window_count)
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(channel=0, cumulative=False, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    seq_gen.PulseCounterControl(channel=1, cumulative=False, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    seq_gen.GetPulseCounterResult(channel=0, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=0)
    seq_gen.GetPulseCounterResult(channel=1, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=1)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=2)
    seq_gen.LoopEnd()
    if (loop_delay > 0):
        seq_gen.TimerWaitAndRestart(loop_delay)
    seq_gen.LoopEnd()
    seq_gen.Stop(0)

    meas_unit.ClearOperations()
    meas_unit.resultAddresses[op_id] = range(0, (window_count * 2))
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)
    await meas_unit.ExecuteMeasurement(op_id)
    results = await meas_unit.ReadMeasurementValues(op_id)

    # Convert accumulated counts into per-flash averages.
    results = [result / flash_count for result in results]

    return results
    

async def flash_as_scan(pmt: str, pos_start=1.0, pos_stop=1.2, pos_step=0.01, flash_count=1, flash_power=None, high_power=None):
    """Sweep one PMT's aperture slider and log a flash excitation trace per position.

    Positions run from ``pos_start`` to ``pos_stop`` in steps of ``pos_step``
    on a 1e-6 grid. ``pos_stop`` is included when it lies on that grid. For
    each position a 100 x 1 us ``flash_excitation_scan`` is run and written
    as one row of a ``;``-separated CSV file named
    ``flash_as_scan__<hostname>_<YYYYmmdd_HHMMSS>.csv`` in the current
    working directory.

    :param pmt: ``'pmt1'`` or ``'pmt2'``.
    :param flash_power: Flash lamp power; if None it is read from the
        instrument so the CSV header records the value used.
    :param high_power: High-power enable; if None it is read from the
        instrument.
    :returns: A completion message.
    :raises ValueError: If ``pmt`` is unknown or ``pos_step`` rounds to zero.
        Both are checked before the hardware is initialized and before the
        CSV file is created.
    """
    window_us = 1.0
    window_count = 100

    # Validate up front so a bad argument neither initializes the hardware
    # nor leaves an empty CSV file behind.
    if pmt not in ('pmt1', 'pmt2'):
        raise ValueError("pmt must be pmt1 or pmt2")
    step = round(pos_step * 1e6)
    if step == 0:
        raise ValueError("pos_step must not round to zero (resolution 1e-6)")
    # Build the position list on an integer 1e-6 grid to avoid float drift.
    pos_range = [pos / 1e6 for pos in range(round(pos_start * 1e6), round(pos_stop * 1e6 + 1), step)]

    await init()

    if flash_power is None:
        flash_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampPower, timeout=1))[0]
    if high_power is None:
        high_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampHighPowerEnable, timeout=1))[0]

    instrument = socket.gethostname()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    with open(f"flash_as_scan__{instrument}_{timestamp}.csv", 'w') as file:
        file.write(f"flash_as_scan(pmt={pmt}, pos_start={pos_start:.3f}, pos_stop={pos_stop:.3f}, pos_step={pos_step:.3f}, flash_count={flash_count}, flash_power={flash_power:.3f}, high_power={high_power}) on {instrument} started at {timestamp}\n")
        # Column headers: end time of each counting window.
        file.write("time [us]")
        for i in range(window_count):
            file.write(f" ; {(window_us * (i + 1)):5.1f}")
        file.write("\n")
        for as_pos in pos_range:
            results = await flash_excitation_scan(pmt=pmt, as_pos=as_pos, window_us=window_us, window_count=window_count, flash_count=flash_count, flash_power=flash_power, high_power=high_power)
            file.write(f"AS={as_pos:.3f} ")
            for i in range(window_count):
                file.write(f" ; {results[i]:5.1f}")
            file.write("\n")

    return "flash_as_scan() done"
    

async def flash_as_dual_scan(pos_start=1.0, pos_stop=1.2, pos_step=0.01, flash_count=1, flash_power=None, high_power=None):
    """Sweep both aperture sliders together and log dual-PMT flash traces per position.

    Both detector aperture sliders are moved to the same position for each
    step from ``pos_start`` to ``pos_stop`` in steps of ``pos_step`` on a
    1e-6 grid. ``pos_stop`` is included when it lies on that grid. For each
    position a 100 x 1 us ``flash_excitation_dual_scan`` is run. It is
    written as two rows (``AS1=...`` from counter channel 0, ``AS2=...``
    from channel 1) of a ``;``-separated CSV file named
    ``flash_as_dual_scan__<hostname>_<YYYYmmdd_HHMMSS>.csv`` in the current
    working directory.

    :param flash_power: Flash lamp power; if None it is read from the
        instrument so the CSV header records the value used.
    :param high_power: High-power enable; if None it is read from the
        instrument.
    :returns: A completion message.
    :raises ValueError: If ``pos_step`` rounds to zero. This is checked
        before the hardware is initialized and before the CSV file is created.
    """
    window_us = 1.0
    window_count = 100

    # Validate up front so a bad step neither initializes the hardware nor
    # leaves an empty CSV file behind.
    step = round(pos_step * 1e6)
    if step == 0:
        raise ValueError("pos_step must not round to zero (resolution 1e-6)")
    # Build the position list on an integer 1e-6 grid to avoid float drift.
    pos_range = [pos / 1e6 for pos in range(round(pos_start * 1e6), round(pos_stop * 1e6 + 1), step)]

    await init()

    if flash_power is None:
        flash_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampPower, timeout=1))[0]
    if high_power is None:
        high_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampHighPowerEnable, timeout=1))[0]

    instrument = socket.gethostname()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    with open(f"flash_as_dual_scan__{instrument}_{timestamp}.csv", 'w') as file:
        file.write(f"flash_as_dual_scan(pos_start={pos_start:.3f}, pos_stop={pos_stop:.3f}, pos_step={pos_step:.3f}, flash_count={flash_count}, flash_power={flash_power:.3f}, high_power={high_power}) on {instrument} started at {timestamp}\n")
        # Column headers: end time of each counting window.
        file.write("time [us]")
        for i in range(window_count):
            file.write(f" ; {(window_us * (i + 1)):5.1f}")
        file.write("\n")
        for as_pos in pos_range:
            results = await flash_excitation_dual_scan(as1_pos=as_pos, as2_pos=as_pos, window_us=window_us, window_count=window_count, flash_count=flash_count, flash_power=flash_power, high_power=high_power)
            # Results are interleaved per window: even = channel 0, odd = channel 1.
            file.write(f"AS1={as_pos:.3f}")
            for i in range(0, (window_count * 2), 2):
                file.write(f" ; {results[i]:5.1f}")
            file.write("\n")
            file.write(f"AS2={as_pos:.3f}")
            for i in range(1, (window_count * 2), 2):
                file.write(f" ; {results[i]:5.1f}")
            file.write("\n")

    return "flash_as_dual_scan() done"
    

async def flash_scan(pmt: str, as_pos, iterations=10, flash_count=1, flash_power=None, high_power=None):
    """Repeat a single-PMT flash excitation scan and log every run to CSV.

    Runs ``iterations`` 100 x 1 us ``flash_excitation_scan`` calls at a fixed
    aperture position. Each run is written as one ``SCAN #nn`` row of a
    ``;``-separated CSV file named
    ``flash_scan__<hostname>_<YYYYmmdd_HHMMSS>.csv`` in the current working
    directory.

    :param pmt: ``'pmt1'`` or ``'pmt2'``.
    :param as_pos: Aperture slider position for the selected PMT.
    :param flash_power: Flash lamp power; if None it is read from the
        instrument so the CSV header records the value used.
    :param high_power: High-power enable; if None it is read from the
        instrument.
    :returns: A completion message.
    :raises ValueError: If ``pmt`` is unknown. This is checked before the
        hardware is initialized and before the CSV file is created.
    """
    window_us = 1.0
    window_count = 100

    # Validate up front so a typo neither initializes the hardware nor leaves
    # an empty CSV file behind.
    if pmt not in ('pmt1', 'pmt2'):
        raise ValueError("pmt must be pmt1 or pmt2")

    await init()

    if flash_power is None:
        flash_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampPower, timeout=1))[0]
    if high_power is None:
        high_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampHighPowerEnable, timeout=1))[0]

    instrument = socket.gethostname()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    with open(f"flash_scan__{instrument}_{timestamp}.csv", 'w') as file:
        file.write(f"flash_scan(pmt={pmt}, as_pos={as_pos:.3f}, iterations={iterations}, flash_count={flash_count}, flash_power={flash_power:.3f}, high_power={high_power}) on {instrument} started at {timestamp}\n")
        # Column headers: end time of each counting window.
        file.write("time [us]")
        for i in range(window_count):
            file.write(f" ; {(window_us * (i + 1)):5.1f}")
        file.write("\n")
        for i in range(iterations):
            results = await flash_excitation_scan(pmt=pmt, as_pos=as_pos, window_us=window_us, window_count=window_count, flash_count=flash_count, flash_power=flash_power, high_power=high_power)
            file.write(f"SCAN #{(i+1):02d} ")
            for j in range(window_count):
                file.write(f" ; {results[j]:5.1f}")
            file.write("\n")

    return "flash_scan() done"
    

async def flash_dual_scan(as1_pos, as2_pos, iterations=10, flash_count=1, flash_power=None, high_power=None):
    """Repeat a dual-PMT flash excitation scan at fixed apertures and log it to CSV.

    Each of the ``iterations`` runs is a 100 x 1 us ``flash_excitation_dual_scan``.
    It produces two rows, ``PMT1 #nn`` (even result indices) and ``PMT2 #nn``
    (odd result indices), in a ``;``-separated CSV file named
    ``flash_dual_scan__<hostname>_<YYYYmmdd_HHMMSS>.csv`` in the current
    working directory.

    :param as1_pos: Position for detector aperture slider 1.
    :param as2_pos: Position for detector aperture slider 2.
    :param flash_power: Flash lamp power; if None it is read from the
        instrument so the CSV header records the value used.
    :param high_power: High-power enable; if None it is read from the
        instrument.
    :returns: A completion message.
    """
    window_us = 1.0
    window_count = 100

    await init()

    if flash_power is None:
        power_reply = await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampPower, timeout=1)
        flash_power = power_reply[0]
    if high_power is None:
        high_power_reply = await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampHighPowerEnable, timeout=1)
        high_power = high_power_reply[0]

    host = socket.gethostname()
    started_at = datetime.now().strftime('%Y%m%d_%H%M%S')
    csv_name = f"flash_dual_scan__{host}_{started_at}.csv"
    with open(csv_name, 'w') as csv_file:
        csv_file.write(f"flash_dual_scan(as1_pos={as1_pos:.3f}, as2_pos={as2_pos:.3f}, iterations={iterations}, flash_count={flash_count}, flash_power={flash_power:.3f}, high_power={high_power}) on {host} started at {started_at}\n")
        # Column headers: end time of each counting window.
        time_columns = ''.join(f" ; {(window_us * (k + 1)):5.1f}" for k in range(window_count))
        csv_file.write("time [us]" + time_columns + "\n")

        for scan_no in range(1, iterations + 1):
            counts = await flash_excitation_dual_scan(as1_pos=as1_pos, as2_pos=as2_pos, window_us=window_us, window_count=window_count, flash_count=flash_count, flash_power=flash_power, high_power=high_power)
            # Interleaved results: offset 0 -> PMT1 row, offset 1 -> PMT2 row.
            for row_label, offset in (("PMT1", 0), ("PMT2", 1)):
                csv_file.write(f"{row_label} #{scan_no:02d} ")
                for idx in range(offset, window_count * 2, 2):
                    csv_file.write(f" ; {counts[idx]:5.1f}")
                csv_file.write("\n")

    return "flash_dual_scan() done"
    

async def flash_scan_graph(pmt: str, as_pos, flash_count=1, duration_s=30.0, delay_s=1.0, maximum=60):
    """Repeatedly run a single-PMT flash scan and print each trace as an ASCII chart.

    Scans are started while no more than ``duration_s`` seconds have elapsed.
    The check happens before each scan, so the last scan may begin just
    before the deadline. Each 100 x 1 us trace is printed with
    ``print_graph`` scaled to ``maximum``, followed by a ``delay_s`` pause.

    :param pmt: ``'pmt1'`` or ``'pmt2'``.
    :param as_pos: Aperture slider position for the selected PMT.
    :param maximum: Value mapped to the top row of the chart.
    :returns: A completion message.
    :raises ValueError: If ``pmt`` is unknown; checked before ``init()``.
    """
    window_us = 1.0
    window_count = 100

    # Reject an unknown PMT before powering up the hardware.
    if pmt not in ('pmt1', 'pmt2'):
        raise ValueError("pmt must be pmt1 or pmt2")

    await init()

    # Monotonic clock: the elapsed duration is immune to wall-clock adjustments.
    start = time.monotonic()
    while (time.monotonic() - start) <= duration_s:

        results = await flash_excitation_scan(pmt=pmt, as_pos=as_pos, window_us=window_us, window_count=window_count, flash_count=flash_count)

        print(f"{pmt.upper()}:\n")
        print_graph(results, maximum)

        await asyncio.sleep(delay_s)

    return "flash_scan_graph() done"
    

async def flash_dual_scan_graph(as1_pos, as2_pos, flash_count=1, duration_s=30.0, delay_s=1.0, maximum=60):
    """Repeatedly run a dual-PMT flash scan and print both traces as ASCII charts.

    Scans are started while no more than ``duration_s`` seconds have elapsed.
    The check happens before each scan, so the last scan may begin just
    before the deadline. Each run prints a PMT1 chart (even result indices)
    and a PMT2 chart (odd result indices) via ``print_graph``, scaled to
    ``maximum``, followed by a ``delay_s`` pause.

    :param as1_pos: Position for detector aperture slider 1.
    :param as2_pos: Position for detector aperture slider 2.
    :param maximum: Value mapped to the top row of each chart.
    :returns: A completion message.
    """
    window_us = 1.0
    window_count = 100

    await init()

    # Monotonic clock: the elapsed duration is immune to wall-clock adjustments.
    start = time.monotonic()
    while (time.monotonic() - start) <= duration_s:

        results = await flash_excitation_dual_scan(as1_pos=as1_pos, as2_pos=as2_pos, window_us=window_us, window_count=window_count, flash_count=flash_count)

        # Results are interleaved per window: even = PMT1, odd = PMT2.
        results_pmt1 = results[0::2]
        results_pmt2 = results[1::2]

        print("PMT1:\n")
        print_graph(results_pmt1, maximum)

        print("PMT2:\n")
        print_graph(results_pmt2, maximum)

        await asyncio.sleep(delay_s)

    return "flash_dual_scan_graph() done"


def print_graph(results, maximum):
    """Print an 11-row ASCII column chart of *results* scaled to *maximum*.

    Each value becomes one column with height ``round(value / maximum * 10)``,
    capped at 10. The column is filled with ``'0'`` from the bottom row up to
    that height; all other cells are a grey ``'-'`` (ANSI bright-black). Each
    row is prefixed with its scale label ``round(maximum * (1 - row / 10))``,
    so the top row is labelled ``maximum`` and the bottom row ``0``. Values
    with a negative height draw nothing in their column.
    """
    empty_cell = '\033[90m' + '-' + '\033[0m'
    filled_cell = '0'

    heights = [min(round(value / maximum * 10), 10) for value in results]

    rows = []
    for row in range(11):
        level = 10 - row  # 0 is the bottom row, 10 the top row
        label = round(maximum * (1 - row / 10))
        cells = ''.join(filled_cell if level <= height else empty_cell for height in heights)
        rows.append(f"{label:2d} " + cells + '\n')

    print(''.join(rows))

